import torch
import torch.distributed as dist
import torch.nn.functional as F

try:
    import fused_mix_prec_layer_norm_cuda
except:
    fused_mix_prec_layer_norm_cuda = None

try:
    import fused_weight_gradient_mlp_cuda

    _grad_accum_fusion_available = True
except ImportError:
    _grad_accum_fusion_available = False

from colossalai.shardformer.layer._operation import _reduce
import numpy as np

def reduce_backward_LR(input_, process_group):
    """Apply the ``_ReduceBackward_LR`` autograd function to ``input_``.

    Args:
        input_: value handed to ``_ReduceBackward_LR.apply`` as its first
            argument.
        process_group: handed to ``_ReduceBackward_LR.apply`` as its second
            argument.

    Returns:
        Whatever ``_ReduceBackward_LR.apply(input_, process_group)`` returns.
    """
    autograd_fn = _ReduceBackward_LR
    return autograd_fn.apply(input_, process_group)
    

class _ReduceBackward_LR(torch.autograd.Function):
    """
    Identity in the forward pass; low-rank compressed all-reduce in backward.

    The backward pass does not all-reduce the full gradient. It flattens the
    gradient to a 2-D matrix ``M`` and runs a single power-iteration step
    (PowerSGD-style, without error feedback):

        Q  <- random (cols, rank)              # same on every rank
        P  <- all_reduce(M @ Q), orthogonalized
        Q  <- all_reduce(M.T @ P)
        M' <- P @ Q.T                          # approximates the reduced M

    Only ``P`` and ``Q`` are communicated. When the matrix is too small for
    the rank budget (rank would be 0) or the gradient has fewer than two
    dimensions, the gradient is all-reduced exactly instead.

    Args (of ``forward`` / ``apply``):
        input_: input tensor, returned unchanged by ``forward``.
        process_group: process group used for the reductions in ``backward``.
    """

    # Seed source for the random projection. Every rank constructs it with
    # the same seed, so ranks that run ``backward`` the same number of times
    # draw identical projections.
    rng = np.random.RandomState(1234)

    # rank * (rows + cols) is kept at about 1/10 of rows * cols, i.e. the two
    # factors together hold roughly 10% of the elements of the full matrix.
    compression_divisor = 10

    @staticmethod
    def _compression_rank(rows, cols):
        """Return the approximation rank for a ``rows x cols`` matrix.

        The result is ``rows * cols // (compression_divisor * (rows + cols))``
        and is 0 when the matrix is too small (or empty) to be compressed.
        """
        denominator = _ReduceBackward_LR.compression_divisor * (rows + cols)
        if denominator <= 0:
            return 0
        return (rows * cols) // denominator

    @staticmethod
    def orthogonalize(matrix, eps=1e-8):
        """Orthonormalize the columns of a 2-D ``matrix`` in place.

        Each column is normalized and its component is then removed from all
        later columns (Gram-Schmidt, one column at a time).

        Args:
            matrix: 2-D tensor, modified in place.
            eps: value added to each column norm before dividing, so a zero
                column does not divide by zero. May be a float or a 0-dim
                tensor.

        Returns:
            None. The result is written into ``matrix``.
        """
        if matrix.is_floating_point():
            # 1e-8 rounds to 0 in float16, which would turn a zero column
            # into NaN; never let eps drop below the dtype's smallest normal.
            eps = max(float(eps), torch.finfo(matrix.dtype).tiny)
        _, num_cols = matrix.shape
        for i in range(num_cols):
            # Normalize the i'th column (a view, so this edits ``matrix``).
            col = matrix[:, i : i + 1]
            col /= torch.sqrt(torch.sum(col ** 2)) + eps
            # Remove the i'th column's component from the remaining columns.
            if i + 1 < num_cols:
                rest = matrix[:, i + 1 :]
                rest -= torch.sum(col * rest, dim=0) * col

    @staticmethod
    def forward(ctx, input_, process_group):
        """Store ``process_group`` on ``ctx`` and return ``input_`` unchanged."""
        ctx.process_group = process_group
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        """Return a low-rank approximation of the all-reduced ``grad_output``.

        Returns:
            ``(grad_input, None)``; ``None`` is the gradient slot for
            ``process_group``.
        """
        group = ctx.process_group

        if grad_output.dim() < 2:
            # No matrix to factorize. Reduce a contiguous copy so that
            # ``grad_output`` itself is not modified in place.
            exact = grad_output.clone(memory_format=torch.contiguous_format)
            return _reduce(exact, group), None

        original_shape = grad_output.shape
        # ``reshape`` (not ``view``) also accepts non-contiguous gradients,
        # and ``-1`` folds any number of leading dimensions.
        grad_2d = grad_output.reshape(-1, original_shape[-1])
        rows, cols = grad_2d.shape
        rank = _ReduceBackward_LR._compression_rank(rows, cols)

        # Draw the seed on every call, including the fallback below, so the
        # shared seed stream advances once per backward call.
        seed = int(_ReduceBackward_LR.rng.randint(1_000_000_000))

        if rank == 0:
            # A rank-0 factorization would yield an all-zero gradient;
            # fall back to an exact all-reduce for such small matrices.
            exact = grad_output.clone(memory_format=torch.contiguous_format)
            return _reduce(exact, group), None

        # A private generator gives every rank the same projection without
        # re-seeding the process-wide RNG used by dropout and other ops.
        generator = torch.Generator(device=grad_2d.device)
        generator.manual_seed(seed)
        q = torch.randn(
            cols,
            rank,
            generator=generator,
            device=grad_2d.device,
            dtype=grad_2d.dtype,
        )
        # Orthogonalizing ``q`` here is optional and is skipped.

        world_size = dist.get_world_size(group)
        p = grad_2d.matmul(q)
        p = _reduce(p, group)
        p /= world_size
        _ReduceBackward_LR.orthogonalize(p)

        q = grad_2d.t().matmul(p)
        q = _reduce(q, group)

        approx = p.matmul(q.t())
        return approx.reshape(original_shape), None